from ASVBaselineTool.ENSATTACK import ENS
import Config.attackconfig as Config
import Config.globalconfig as Gconfig
import torch
import yaml
from LFCC_RESNET.model import Model_zoo as LFCC_RESNET_MODEL
from MFCC_RESNET.model import Model_zoo as MFCC_RESNET_MODEL
from SPEC_RESNET_NEW.model import Model_zoo as SPEC_RESNET_MODEL
from RawNet.model import RawNet as RAWNET_MODEL


# --- Run configuration, read from Config.attackconfig ---
labelPath = Config.TLABEL
datasetPath = Config.TPATH
AdvsetPath = Config.ENS64600_AdvsetPath
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)


def _load_weights(model, model_path, label):
    """Move *model* to ``device``, load its saved weights and put it in eval mode.

    Args:
        model: an instantiated ``torch.nn.Module`` whose architecture matches
            the checkpoint stored at *model_path*.
        model_path: path to a saved ``state_dict``.
        label: human-readable model name used in the progress message.

    Returns:
        The same model, on ``device``, with weights loaded and in eval mode.
    """
    model = model.to(device)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; on torch >= 1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # The models are attack targets, not being trained: eval() keeps any
    # BatchNorm running statistics frozen and disables dropout, so the
    # attack's gradients are computed against the deployed behaviour.
    model.eval()
    print('{} Model loaded : {}'.format(label, model_path))
    return model


# Configure RawNet ("64600" variant; presumably the input length in samples — confirm)
RAWNET_model_path = Config.RawNet_model
yaml_path = Gconfig.RAW_YAML_CONFIG_PATH
with open(yaml_path, 'r') as f_yaml:
    parser1 = yaml.load(f_yaml, Loader=yaml.FullLoader)
model0 = _load_weights(RAWNET_MODEL(parser1['model'], device),
                       RAWNET_model_path, 'RAWNET 64600')

# Configure the MFCC-feature ResNet18
MFCC_RESNET_model_path = Config.MFCC_RESNET_model
model1 = _load_weights(MFCC_RESNET_MODEL("ResNet18"),
                       MFCC_RESNET_model_path, 'MFCC RESNET 64600')

# Configure the LFCC-feature ResNet18
LFCC_RESNET_model_path = Config.LFCC_RESNET_model
model2 = _load_weights(LFCC_RESNET_MODEL("ResNet18"),
                       LFCC_RESNET_model_path, 'LFCC RESNET 64600')

# Configure the spectrogram-feature ResNet18
SPEC_RESNET_model_path = Config.SPEC_RESNET_NEW_model
model3 = _load_weights(SPEC_RESNET_MODEL("ResNet18"),
                       SPEC_RESNET_model_path, 'SPEC RESNET 64600')

# Run the ensemble attack over all four models.
# NOTE(review): presumably steps = max iterations, eps = perturbation bound,
# alpha = per-step size, stop = early-stop threshold — confirm against ENS.
ensattack = ENS(model0, model1, model2, model3, labelPath, datasetPath,
                AdvsetPath, steps=60, eps=0.1)
ensattack.Attack(alpha=0.0002, stop=0.001)